# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CLI utility for running inference on a single/multi stream(s)."""

import os
from typing import Sequence
import jax
import jax.numpy as jnp

from absl import app

from jetstream.engine import engine_api

from MaxText import max_utils
from MaxText import maxengine
from MaxText import pyconfig
from MaxText import profiler
from MaxText import multimodal_utils
# Placeholder: internal

# Number of text sequences to process in a single batch.
_NUM_STREAMS = 1


def _batch_first_result_token(first_tokens: list[engine_api.ResultTokens], batch_size: int):
  """Combines per-stream first tokens from prefill into one generate-shaped batch.

  Each prefill call hands back its first token as its own small `ResultTokens`
  (a batch of one, which avoids padding on the latency-critical path), whereas
  `generate` emits a `ResultTokens` covering the full configured batch. To treat
  both uniformly downstream, the prefill results are stacked row-wise and the
  remaining rows are zero-filled up to `batch_size`.

  Args:
    first_tokens: The `ResultTokens` returned by each `prefill` call, one per stream.
    batch_size: Number of rows the returned batch must have (the configured batch size).

  Returns:
    A single `ResultTokens` holding every first token, laid out as though one
    `generate` step had produced them.

  Raises:
    AssertionError: If any input disagrees with the hardcoded `tokens_idx`,
      `valid_idx`, `length_idx` or `samples_per_slot` of the batched result.
  """
  # One row per stream, in the order the prefill results were supplied.
  stacked_rows = jnp.vstack([token.data for token in first_tokens])

  # Append all-zero rows at the end of the batch axis; columns are left untouched.
  rows_to_add = batch_size - stacked_rows.shape[0]
  padded_rows = jnp.pad(
      stacked_rows,
      [(0, rows_to_add), (0, 0)],
      mode="constant",
      constant_values=0,
  )

  batched = engine_api.ResultTokens(
      data=padded_rows,
      tokens_idx=(0, 1),
      valid_idx=(1, 2),
      length_idx=(2, 3),
      samples_per_slot=1,
  )

  def _field_matches_everywhere(field_name: str):
    """Reports whether every input carries the same `field_name` value as `batched`."""
    observed = jnp.stack([jnp.array(getattr(token, field_name)) for token in first_tokens])
    expected = jnp.array(getattr(batched, field_name))
    return jnp.all(observed == expected)

  # The layout metadata above is hardcoded, so it is only valid if every prefill
  # result was produced with exactly the same layout.
  for field_name in ("tokens_idx", "valid_idx", "length_idx", "samples_per_slot"):
    assert _field_matches_everywhere(field_name)

  return batched


def main(argv: Sequence[str]) -> None:
  """Runs prefill, insert and autoregressive generation for the configured prompt.

  The prompt from the config is tokenized (and combined with image inputs when
  `config.use_multimodal` is set), prefilled once per stream, inserted into a fresh
  decode state and then decoded one step at a time. The decoded text of each stream
  is printed, and the output is checked against `config.autoregressive_decode_assert`.

  Args:
    argv: Command-line arguments, forwarded to `pyconfig.initialize`.

  Raises:
    AssertionError: If the config fails validation, the tokenized prompt is longer than
      `config.max_prefill_predict_length`, an unsupported quantization is requested,
      `_NUM_STREAMS` does not fit into the batch size, or the generated text does not
      start with `config.autoregressive_decode_assert`.
  """
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  config = pyconfig.initialize(argv)
  _validate_config(config)
  jax.config.update("jax_use_shardy_partitioner", config.shardy)
  max_utils.print_system_information()

  engine = maxengine.MaxEngine(config)
  rng = jax.random.PRNGKey(1234)
  rng, rng_load_params = jax.random.split(rng)
  params = engine.load_params(rng_load_params)
  prof = profiler.Profiler(config)

  text = config.prompt
  prefill_length = config.max_prefill_predict_length
  processor_outputs = multimodal_utils.PreprocessorOutput()
  if config.use_multimodal:
    image_path = config.image_path.split(",")
    images = [multimodal_utils.load_image_from_path(p) for p in image_path]
    processor_outputs = multimodal_utils.pre_process_image(images, model_name=config.model_name)
    image_offsets = multimodal_utils.get_image_offsets(config.model_name, processor_output=processor_outputs)

    # The offset is taken out of the text budget here and added back to `true_length`
    # after image fusion below (presumably the positions occupied by image tokens).
    prefill_length -= image_offsets
    text = multimodal_utils.reformat_prompt(
        text, image_placeholder=config.image_placeholder, model_name=config.model_name, num_images=len(images)
    )

  metadata = engine.get_tokenizer()
  tokenizer_model = engine.build_tokenizer(metadata)
  try:
    # TODO: update jetstream.engine.tokenizer_api.Tokenizer to maintain tokenizer state.
    has_chat_template = getattr(tokenizer_model.tokenizer, "chat_template", False)  # pytype: disable=attribute-error
  except AttributeError:
    # `tokenizer_model` has no `.tokenizer` attribute; treat it as having no chat template.
    has_chat_template = False
  # `is_bos` is switched off whenever the tokenizer carries a chat template.
  tokens, true_length = tokenizer_model.encode(text, is_bos=not has_chat_template, prefill_lengths=[prefill_length])
  if config.use_multimodal:
    tokens = multimodal_utils.prepare_text_for_image_fusion(
        tokens, model_name=config.model_name, processor_output=processor_outputs
    )
    true_length += image_offsets

  assert (
      true_length <= config.max_prefill_predict_length
  ), f"Input token length {true_length} is longer than {config.max_prefill_predict_length=}"
  assert config.quantization != "fp8", "fp8 on NVIDIA GPUs is not supported in decode.py yet"
  assert config.quantization != "nanoo_fp8", "NANOO fp8 on AMD MI300/MI325 GPUs is not supported in decode.py yet"

  batch_size = int(config.per_device_batch_size * jax.device_count())
  assert (
      0 < _NUM_STREAMS <= batch_size
  ), f"The number of streams {_NUM_STREAMS} must be > 0 and <= batch size {batch_size}"

  prefill_result_list = []
  first_token_list = []
  sampled_tokens_list = []

  prof.activate(optional_postfix="trace")

  # Prefill: the same prompt and the same prefill RNG are used for every stream.
  rng, rng_prefill = jax.random.split(rng)  # Split RNG before calling prefill
  for i in range(_NUM_STREAMS):
    with jax.profiler.StepTraceAnnotation("prefill", stream=i):
      prefill_result, first_token = engine.prefill(
          params=params,
          padded_tokens=tokens,
          images=processor_outputs.pixel_values if config.use_multimodal else None,
          image_masks=processor_outputs.pixel_mask if config.use_multimodal and "llama4" in config.model_name else None,
          true_length=true_length,
          rng=rng_prefill,
          slot=i,
      )
    prefill_result_list.append(prefill_result)
    first_token_list.append(first_token)

  # Insert: place each stream's prefill result into its own slot of the decode state.
  rng, rng_init_decode = jax.random.split(rng)
  decode_state = engine.init_decode_state(rng_init_decode)
  for i in range(_NUM_STREAMS):
    decode_state = engine.insert(prefill_result_list[i], decode_state, slot=i)

  # Generate: one step per position from the prefill length up to the target length.
  prof_deactivated = False
  steps = range(config.max_prefill_predict_length, config.max_target_length)
  # The first token of every stream came from prefill; batch them so that they have
  # the same layout as the per-step results appended below.
  sampled_tokens_list.append(_batch_first_result_token(first_token_list, batch_size))
  for i in steps:
    rng, rng_generate = jax.random.split(rng)
    with jax.profiler.StepTraceAnnotation("generate", step=i):
      decode_state, sampled_tokens = engine.generate(params, decode_state, rng=rng_generate)

    # Automatically deactivate profiler after profiler_steps steps. The flag guard
    # ensures this happens exactly once instead of on every subsequent step.
    if not prof_deactivated and i > config.max_prefill_predict_length + config.profiler_steps:
      prof.deactivate(blocking_object=sampled_tokens)
      prof_deactivated = True

    sampled_tokens_list.append(sampled_tokens)

  # Get results: collect each stream's token from every step and decode it to text.
  for i in range(_NUM_STREAMS):
    results = [t.get_result_at_slot(i).tokens.item() for t in sampled_tokens_list]
    output = tokenizer_model.decode(results)
    print(f"Input `{text}` -> `{output}`")

  # NOTE: `output` here is the text of the last stream processed by the loop above.
  assert output.startswith(
      config.autoregressive_decode_assert
  ), f"generated text mismatch {output=}, {config.autoregressive_decode_assert=}"

  # Deactivate profiler if the generate loop ended before the automatic cut-off.
  if not prof_deactivated:
    prof.deactivate(blocking_object=output)

  prof.post_process()


def _validate_config(config):
  """Checks that the config is usable for decoding.

  Decoding works from a parameter-only checkpoint, so a full-state checkpoint path
  must not be configured.

  Args:
    config: Config object exposing a `load_full_state_path` attribute.

  Raises:
    AssertionError: If `config.load_full_state_path` is not the empty string.
  """
  # The two literals are concatenated, so the first must end with a space to keep
  # the sentences separated in the final message.
  assert config.load_full_state_path == "", (
      "Decode doesn't operate on full states! Convert to parameter checkpoint first. "
      "Using generate_param_only_checkpoint."
  )


# Script entry point: only when this file is executed directly (not imported),
# hand control to absl's app runner, which calls `main`.
if __name__ == "__main__":
  app.run(main)
